
# Load data ---------------------------------------------------------------

# Combine the four party columns below into a single rrp_zweit_sh column:
# for each row, coalesce() takes the first non-missing value in the order
# DRP, NPD, REP, AfD. The source columns are then dropped and rrp_zweit_sh
# is placed first.
col_combine <- c('drp_zweit_sh', 'npd_zweit_sh', 'rep_zweit_sh', 'afd_zweit_sh')
btw_main <- btw_kreis_harm_53_17 %>%
  # do.call() spreads the selected columns as separate arguments to
  # coalesce(); it replaces the deprecated purrr::invoke() with identical
  # behaviour.
  dplyr::mutate(
    rrp_zweit_sh = do.call(coalesce, across(all_of(col_combine)))
  ) %>%
  dplyr::select(rrp_zweit_sh, !all_of(col_combine))


# Build the analysis variables on btw_main (updated in place via %<>%):
# treatment coding, a lagged combined vote share, rescaled columns, covariates,
# IDs and within-Kreis differences; then keep elections from 1969 onwards.
btw_main %<>%
  # Lag rrp_zweit_sh within each Kreis (ags_proj). An ungrouped lag would give
  # the first row of every Kreis the last value of the preceding Kreis.
  # Rows are taken to be in time order within each Kreis, the same assumption
  # the lead() further below relies on.
  group_by(ags_proj) %>%
  mutate(
    across(rrp_zweit_sh, ~ dplyr::lag(.x, n = 1L), .names = "{.col}_l1")
  ) %>%
  ungroup() %>%
  mutate(
    # "Entry" / "No Change" / "Exit" of the Greens in the state parliament;
    # any other combination of the indicators is left NA.
    parl_treat_fct = case_when(
      greens_in_st_parl == 1 &
        (greens_first_entry_st_parl == 1 |
           greens_any_entry_st_parl == 1) &
        greens_exit_st_parl == 0 ~ "Entry",
      (greens_in_st_parl == 1 | greens_in_st_parl == 0) &
        greens_any_entry_st_parl == 0 &
        greens_exit_st_parl == 0 ~ "No Change",
      (greens_in_st_parl == 0) &
        greens_any_entry_st_parl == 0 & greens_exit_st_parl == 1 ~ "Exit"
    ),
    # Per-state constant keyed on state_id 1-16 (NA for any other id).
    treatment_gs = case_when(
      state_id == 1 ~ 13,
      state_id == 2 ~ 9,
      state_id == 3 ~ 9,
      state_id == 4 ~ 8,
      state_id == 5 ~ 11,
      state_id == 6 ~ 9,
      state_id == 7 ~ 10,
      state_id == 8 ~ 8,
      state_id == 9 ~ 10,
      state_id == 10 ~ 13,
      state_id == 11 ~ 9,
      state_id == 12 ~ 11,
      state_id == 13 ~ 11,
      state_id == 14 ~ 11,
      state_id == 15 ~ 11,
      state_id == 16 ~ 11
    ),
    across(c(parl_treat_fct), factor),
    # Multiply these column ranges by 100 (presumably shares -> percent;
    # TODO confirm against the source data).
    across(c(turnout:cdu_csu_zweit_sh, afd_zweit_sh_l1:delta_cdu_csu), ~ .x * 100),
    pop_dens = tot_pop / area_km2,
    cluster_id = as.numeric(state_id),
    ags_num = as.numeric(ags_proj),
    time_id = as.integer(factor(year)),
    # Binary treatment: "Entry" vs. everything else (Exit, No Change and NA
    # all become "No change").
    parl_treat_entry = case_when(parl_treat_fct == "Entry" ~ "Entry",
                                 TRUE ~ 'No change'),
    parl_treat_entry = factor(parl_treat_entry)
  ) %>%
  group_by(ags_proj) %>%
  # delta_<col> = value minus the next row's value within the same Kreis.
  # NOTE(review): if rows are in ascending year order this is x[t] - x[t+1],
  # i.e. the negative of the change to the next election - confirm the
  # intended sign / ordering.
  mutate(across(
    c(turnout, cdu_csu_zweit_sh, rrp_zweit_sh, fdp_zweit_sh,
      die_linke_zweit_sh, spd_zweit_sh),
    function(x) {
      x - dplyr::lead(x)
    },
    .names = "delta_{col}"
  )) %>%
  filter(year >= 1969) %>%
  ungroup()

# Split the panel by state_id: ids 1-10 go to btw_west, ids 11 and above go
# to btw_east. Both are returned as plain data.frames.
btw_west <- as.data.frame(filter(btw_main, state_id < 11))
btw_east <- as.data.frame(filter(btw_main, state_id >= 11))



# First Difference --------------------------------------------------------

# Fit the first-difference models without bootstrap inference (bootstrapped
# SEs are computed in the next step). In each call, sw0() gives one variant
# without and one with the covariate set, and sw() steps through four
# fixed-effect specifications, so each result is a fixest_multi of 8 models.
# Rows with a missing delta_cdu_csu are dropped before fitting.

# Family 1: fixed-effect specifications built around Kreis (ags_num) effects.
# NOTE(review): this family controls for log_gdp_capita_pps while the
# state-FE family below uses gdp_capita_pps - confirm this is intended.
parl_fd_agsfe_mods <- feols(
  delta_cdu_csu ~ i(parl_treat_fct, ref = "No Change") +
    sw0(turnout + pop_dens + emp_sh + log_gdp_capita_pps +
          gdp_growth_rt_perc + agri_emp_sh + manu_gva_sh) |
    sw(ags_num,
       ags_num + year,
       ags_num + year + ags_num[year],
       state_id + ags_num + year),
  data = drop_na(btw_main, delta_cdu_csu)
)

# Family 2: fixed-effect specifications built around state (state_id) effects.
parl_fd_statefe_mods <- feols(
  delta_cdu_csu ~ i(parl_treat_fct, ref = "No Change") +
    sw0(turnout + pop_dens + emp_sh + gdp_capita_pps +
          gdp_growth_rt_perc + agri_emp_sh + manu_gva_sh) |
    sw(state_id,
       state_id + year,
       state_id + year + state_id[year],
       state_id + ags_num + year),
  data = drop_na(btw_main, delta_cdu_csu)
)

# Calculate cluster-bootstrapped SEs (sandwich::vcovBS, 1000 replications,
# 6 cores) for every model of both families.
# - The Kreis-FE family is clustered by ags_num.
# - The state-FE family is clustered by state_id.
# Each model is summarised exactly once with its family's cluster variable.
# Previously the inner function's `.y` shadowed the outer one, so the cluster
# names were never used and the whole fixest_multi was re-summarised once per
# model.
# The cluster variable is passed as a one-sided formula (e.g. ~ags_num), a
# form accepted by both fixest's and sandwich's `cluster` arguments.
clus_boot_mods <- map2(
  list(parl_fd_agsfe_mods, parl_fd_statefe_mods),
  c("ags_num", "state_id"),
  function(mods, clus_var) {
    map(mods, function(mod) {
      summary(
        mod,
        vcov = sandwich::vcovBS,
        cluster = reformulate(clus_var),
        cores = 6,
        R = 1000
      )
    })
  }
)


# Create Latex output table
options(knitr.table.format = "latex")

# One plain list of summarised models per table, labelled "FD 1", "FD 2", ...
# in fitting order.
fd_mod_agsse_ls <- clus_boot_mods[[1]]
fd_mod_statese_ls <- clus_boot_mods[[2]]

names(fd_mod_agsse_ls) <- paste("FD", seq_along(fd_mod_agsse_ls))
names(fd_mod_statese_ls) <- paste("FD", seq_along(fd_mod_statese_ls))

# Only the Greens-entry coefficient is shown in the tables. Its label matches
# the "Greens Entry" wording used in the coefficient plots.
cn <- c('parl_treat_fct::Entry' = 'Greens Entry')

# Goodness-of-fit rows to display: raw modelsummary name, printed label and
# number of digits.
gofm <- tribble(~raw, ~clean, ~fmt,
                "nobs", "N", 0,
                "r.squared", "R2", 2,
                "within.r.squared", "R2 Within", 2,
                "FE: year", "FE: Year", 0,
                "FE: state_id", "FE: Land", 0,
                "FE: ags_num", "FE: Kreis", 0
)

# Extra descriptive rows, one column per model (FD 1 ... FD 8). Covariates
# alternate off/on; models 5-6 use the varying-slopes specification.
rows <- tribble(~term, ~ "FD 1", ~ "FD 2", ~ "FD 3", ~ "FD 4", ~ "FD 5", ~ "FD 6", ~ "FD 7", ~ "FD 8",
                "Covs. Included", "No", "Yes", "No", "Yes", "No", "Yes", "No", "Yes",
                "Varying Slopes", "No", "No", "No", "No", "Yes", "Yes", "No", "No",
                "Std. Errors", "Cluster BS", "Cluster BS", "Cluster BS", "Cluster BS", "Cluster BS", "Cluster BS", "Cluster BS", "Cluster BS")

# modelsummary reads this attribute to place the extra rows at table rows 7-9.
attr(rows, 'position') <- c(7, 8, 9)

# Render one LaTeX table per model family and write them to disk:
# the Kreis-clustered family goes to table B2, the state-clustered family to B3.
mods_ls <- list(fd_mod_agsse_ls, fd_mod_statese_ls)
tab_ls <- map(
  mods_ls,
  ~ msummary(
    .x,
    align = "lcccccccc",
    output = "kableExtra",
    fmt = 2,
    estimate = "{estimate}",
    statistic = "conf.int",
    coef_map = cn,
    stars = FALSE,
    gof_map = gofm,
    add_rows = rows,
    escape = FALSE,
    booktabs = TRUE,
    format = "latex",
    linesep = ""
  ) %>%
    kable_styling(
      position = "center",
      latex_options = c("striped", "hold_position", "scale_down")
    ) %>%
    add_header_above(c(
      " " = 1, "DV: CDU/CSU Vote (%)" = 8
    )) %>%
    footnote(
      number = c(
        "Numbers in square brackets represent bootstrapped 95% confidence intervals; "
      )
    )
)

files <- c("output/tables/table_b2.tex", "output/tables/table_b3.tex")
# walk2() is used for the file-writing side effect only. It returns
# invisibly, unlike map2(), which would print a list at top level.
walk2(
  tab_ls,
  files,
  ~ kableExtra::save_kable(.x, keep_tex = TRUE, file = .y)
)


# Turnout analysis --------------------------------------------------------

# Change in turnout regressed on the Greens-entry treatment under two
# fixed-effect specifications:
# - state + year FE;
# - state + year FE plus state-specific year slopes.
# Rows with a missing delta_turnout are dropped first.
turnout_mods <- feols(
  delta_turnout ~ i(parl_treat_fct, ref = "No Change") |
    sw(state_id + year, state_id + year + state_id[year]),
  data = drop_na(btw_main, delta_turnout)
)

# Cluster-bootstrapped (by state) standard errors for each turnout model.
clus_boot_mods <- map(
  turnout_mods,
  function(mod) {
    summary(mod, vcov = sandwich::vcovBS, cluster = "state_id",
            cores = 6, R = 1000)
  }
)


# Label the two bootstrapped turnout models for the coefficient plot.
models <- list(
  "regular FE" = clus_boot_mods[[1]],
  "varying slopes" = clus_boot_mods[[2]]
)

# Tidy estimates with 90% and 95% intervals stacked into one data frame
# (.width records the level); the entry coefficient is relabelled.
dat <- map_dfr(c(.9, .95), function(x) {
  modelplot(models, conf_level = x, draw = FALSE) %>%
    mutate(
      .width = x,
      term = case_when(
        term == "parl_treat_fct::Entry" ~ "Greens Entry",
        TRUE ~ as.character(term)
      )
    )
})

# Coefficient plot of the Greens-entry effect on turnout change, one grey
# interval per specification. Shown, then saved as figure C7.
plot_turnout <- dat %>%
  filter(term == "Greens Entry") %>%
  mutate(term = fct_relevel(factor(term), "Greens Entry", after = Inf)) %>%
  ggplot(aes(y = term, x = estimate,
             xmin = conf.low, xmax = conf.high,
             colour = model)) +
  ggdist::geom_pointinterval(
    position = position_dodge(width = 0.25),
    interval_size_range = c(1.5, 2.5),
    fatten_point = .9,
    point_fill = "white"
  ) +
  geom_vline(xintercept = 0, linetype = "dashed") +
  scale_color_grey() +
  labs(
    x = "Effect of Green entry into state parliament on change of turnout.",
    y = ""
  ) +
  theme_hanno() +
  theme(
    axis.text = element_text(size = 12),
    axis.title = element_text(size = 14),
    legend.title = element_blank(),
    # Legend sits inside the panel, bottom-right, laid out horizontally.
    legend.position = c(0.98, 0.02),
    legend.justification = c("right", "bottom"),
    legend.box = "horizontal",
    legend.direction = "horizontal",
    legend.box.just = "right",
    legend.margin = margin(5, 5, 5, 5),
    legend.spacing.x = unit(0.4, 'cm')
  )
plot_turnout
ggsave(
  filename = "output/figures/fig_c7.png",
  plot = plot_turnout,
  device = "png",
  width = 10,
  height = 3,
  dpi = 600
)

# First Difference: FDP, RRP and SPD vote shares ---------------------------


# Changes in FDP, RRP and SPD second-vote shares regressed on the
# Greens-entry treatment, controlling for turnout, on btw_west only.
# Three outcomes x two fixed-effect specifications = six models.
parl_fd_statefe_mods <- feols(
  c(delta_fdp_zweit_sh, delta_rrp_zweit_sh, delta_spd_zweit_sh) ~
    i(parl_treat_fct, ref = "No Change") + turnout |
    sw(state_id + year, state_id + year + state_id[year]),
  data = btw_west
)

# Cluster-bootstrapped (by state) standard errors for each of the six models.
clus_boot_mods <- map(parl_fd_statefe_mods, function(mod) {
  summary(mod, vcov = sandwich::vcovBS, cluster = ~state_id,
          cores = 6, R = 1000)
})



# Create Coefplot ---------------------------------------------------------

# Label the six bootstrapped models for the coefficient plot.
# The indices encode the fixest_multi order for outcomes c(fdp, rrp, spd):
# - 1-3 are FDP/RRP/SPD with state + year FE;
# - 4-6 are the same outcomes with varying slopes.
models <- list(
  "FDP: regular FE" = clus_boot_mods[[1]],
  "FDP: varying slopes" = clus_boot_mods[[4]],
  "SPD: regular FE" = clus_boot_mods[[3]],
  "SPD: varying slopes" = clus_boot_mods[[6]],
  "RRP: regular FE" = clus_boot_mods[[2]],
  "RRP: varying slopes" = clus_boot_mods[[5]]
)

# Tidy estimates with 90% and 95% intervals stacked into one data frame
# (.width records the level); the entry coefficient is relabelled.
dat <- map_dfr(c(.9, .95), function(x) {
  modelplot(models, conf_level = x, draw = FALSE) %>%
    mutate(
      .width = x,
      term = case_when(
        term == "parl_treat_fct::Entry" ~ "Greens Entry",
        TRUE ~ as.character(term)
      )
    )
})

# Coefficient plot of the Greens-entry effect on party vote-share changes:
# colour encodes the party, transparency the specification. Shown, then
# saved as figure C8.
plot_other_parties <- dat %>%
  filter(term == "Greens Entry") %>%
  mutate(
    party = str_sub(model, end = 3L),   # "FDP: regular FE" -> "FDP"
    model = str_sub(model, start = 6L), # "FDP: regular FE" -> "regular FE"
    term = fct_relevel(factor(term), "Greens Entry", after = Inf)
  ) %>%
  ggplot(aes(y = term, x = estimate,
             xmin = conf.low, xmax = conf.high,
             colour = party, alpha = model)) +
  ggdist::geom_pointinterval(
    position = position_dodge(width = 0.25),
    interval_size_range = c(1.5, 2.5),
    fatten_point = .9,
    point_fill = "white"
  ) +
  geom_vline(xintercept = 0, linetype = "dashed") +
  scale_alpha_manual(values = c('regular FE' = 1, 'varying slopes' = 0.45)) +
  scale_color_manual(values = c('SPD' = 'red', 'FDP' = 'gold', 'RRP' = 'brown')) +
  labs(
    x = "Effect of Green entry into state parliament on change of party vote shares.",
    y = ""
  ) +
  theme_hanno() +
  theme(
    axis.text = element_text(size = 12),
    axis.title = element_text(size = 14),
    legend.title = element_blank(),
    # Legends sit inside the panel, bottom-right, side by side.
    legend.position = c(0.98, 0.02),
    legend.justification = c("right", "bottom"),
    legend.box = "horizontal",
    legend.direction = "horizontal",
    legend.box.just = "right",
    legend.margin = margin(5, 5, 5, 5),
    legend.spacing.x = unit(0.4, 'cm')
  )
plot_other_parties
ggsave(
  filename = "output/figures/fig_c8.png",
  plot = plot_other_parties,
  device = "png",
  width = 10,
  height = 3,
  dpi = 600
)

